# -*- coding: utf-8 -*-
"""
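This file defines a function that takes a list of (segmented) text strings as input and returns, for each one, a predicted probability from the stage-1 classifier (grievance vs. non-grievance), the stage-2 classifier (protest vs. grievance), or a combination of both.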

"""



#%%

import os
os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"   # see issue #152
os.environ["CUDA_VISIBLE_DEVICES"] = "-1"


import pandas as pd
import csv
import codecs
import numpy as np

from CASM_c1_deep_text import *
from CASM_c2_deep_text import *
from keras_self_attention import SeqWeightedAttention

def predict_string(texts, type = "c1"):
	"""Predict one probability per (segmented) text string.

	Args:
		texts (list of str): text strings to score; each string is passed
			unchanged to the underlying classifier(s).
		type (str): which classifier(s) to use
			- "c1": classifier 1 only (grievances vs. non-grievances), deep-learning based
			- "c2": classifier 2 only (protest vs. grievance), deep-learning based
			- "c1c2": weighted average of the stage-1 and stage-2 predictions

	Returns:
		list: one predicted probability per input text, in input order
		(an empty list for empty ``texts``). Returns None, after printing a
		message, when ``type`` is not one of the supported values.
	"""

	if type == "c1":
		return [predict_text_deep(x) for x in texts]

	if type == "c2":
		# Previously this branch returned the undefined name ``prob`` (NameError).
		return [predict_text_deep_protest_vs_grievance(x) for x in texts]

	if type == "c1c2":
		# Weighted average (p1 + 0.95 * p2) / 1.95, as in an earlier draft of
		# this function.
		# NOTE(review): confirm these weights before relying on "c1c2".
		stage2_weight = 0.95
		prob_vector1 = [predict_text_deep(x) for x in texts]
		prob_vector2 = [predict_text_deep_protest_vs_grievance(x) for x in texts]
		return [(p1 + p2 * stage2_weight) / (1.0 + stage2_weight)
				for p1, p2 in zip(prob_vector1, prob_vector2)]

	# Unsupported type: keep the original best-effort behavior (message + None)
	# rather than raising, so existing callers are not broken.
	print("not support classifier type")
	return None

if __name__ == '__main__':

	# Smoke test: score one pre-segmented (space-separated) Chinese post with
	# each single-stage classifier.
	# A u"" literal gives the same unicode text as "...".decode('utf-8') on
	# Python 2 (see the utf-8 coding header) and also works on Python 3, where
	# str has no .decode() method.
	wd = u"你 能 找出 照片 中 的 什么 让 男子 发现 妻子 出轨 吗   一般 人 是 看不出 的     详见 下图   灯钾     一千多年 前 就 马丁 靴   轻 摇滚 的 女学   内置 的 太阳能 电池板 假发 不错     要 不要 试去 扫 这些 手无寸铁 的 老椎病   每天 早上 坚持 两"
	wdl = [wd]
	print(predict_string(wdl, "c1"))
	print(predict_string(wdl, "c2"))


